# adapted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
import torch

try:
    from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
    from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2

    HAS_LIGHTLLM_KERNEL = True
except Exception:
    # lightllm is an optional dependency: any failure while importing its kernels
    # simply disables the fast path below. `except Exception` (instead of a bare
    # `except:`) keeps KeyboardInterrupt / SystemExit from being swallowed here.
    print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
    HAS_LIGHTLLM_KERNEL = False


if HAS_LIGHTLLM_KERNEL:

    def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
        """Run lightllm's two-stage flash-decoding attention for one decode step.

        ``flash_decode_stage1`` fills two float32 scratch buffers,
        ``mid_o`` / ``mid_o_logexpsum``. Judging by their shape, each has one
        slot per BLOCK_SEQ-token chunk. ``flash_decode_stage2`` then combines
        them into ``o_tensor``.

        Args:
            q: query tensor, viewed as ``(batch_size, q_head_num, head_dim)``.
            o_tensor: output tensor, written by stage 2 through a view of the
                same shape as ``q``.
            infer_state: batch inference state. Reads ``batch_size``,
                ``max_len_in_batch``, ``block_loc`` and ``seq_len``. The scratch
                buffers are cached on it as ``mid_o`` / ``mid_o_logexpsum``.
            q_head_num: number of query heads.
            head_dim: size of each attention head.
            cache_k: key cache, passed through to stage 1.
            cache_v: value cache, passed through to stage 1.
        """
        BLOCK_SEQ = 256
        batch_size = infer_state.batch_size
        max_len_in_batch = infer_state.max_len_in_batch

        calcu_shape1 = (batch_size, q_head_num, head_dim)

        # Number of BLOCK_SEQ-sized chunks reserved for the longest sequence (+1 for the tail).
        num_blocks = max_len_in_batch // BLOCK_SEQ + 1
        mid_o_shape = (batch_size, q_head_num, num_blocks, head_dim)
        mid_lse_shape = (batch_size, q_head_num, num_blocks)

        mid_o = getattr(infer_state, "mid_o", None)
        mid_o_logexpsum = getattr(infer_state, "mid_o_logexpsum", None)
        # The buffers are cached on infer_state so repeated calls with the same
        # state reuse them. Rebuild them when either one is missing, or when the
        # batch size, head count, head_dim, max_len_in_batch or device changed
        # since allocation. Otherwise the kernels would be handed a stale,
        # wrongly-shaped (possibly too small) buffer.
        if (
            mid_o is None
            or mid_o_logexpsum is None
            or tuple(mid_o.shape) != mid_o_shape
            or tuple(mid_o_logexpsum.shape) != mid_lse_shape
            or mid_o.device != q.device
        ):
            # Allocate on q's device rather than the implicit current "cuda" device,
            # so the scratch space lives where the query (and the kernel) is.
            mid_o = torch.empty(mid_o_shape, dtype=torch.float32, device=q.device)
            mid_o_logexpsum = torch.empty(mid_lse_shape, dtype=torch.float32, device=q.device)
            infer_state.mid_o = mid_o
            infer_state.mid_o_logexpsum = mid_o_logexpsum

        flash_decode_stage1(
            q.view(calcu_shape1),
            cache_k,
            cache_v,
            infer_state.block_loc,
            infer_state.seq_len,
            infer_state.max_len_in_batch,
            mid_o,
            mid_o_logexpsum,
            BLOCK_SEQ,
        )
        flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
